{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import yaml\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from yaml import CLoader as Loader, CDumper as Dumper\n",
    "except ImportError:\n",
    "    from yaml import Loader, Dumper\n",
    "config = yaml.load(open(\"ensemble_config.yaml\", \"r\"), Loader = yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "keyword = config[\"keyword\"]\n",
    "res_dir = config[\"res_dir\"]\n",
    "run_ids = config[\"run_ids\"]\n",
    "\n",
    "run_id2path_list = {}\n",
    "for run_id in run_ids:\n",
    "    path_pattern = os.path.join(res_dir, run_id, \"*{}*.json\".format(keyword))\n",
    "    if run_id not in run_id2path_list:\n",
    "        run_id2path_list[run_id] = []\n",
    "    for path in glob(path_pattern):\n",
    "        run_id2path_list[run_id].append(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only last k\n",
    "k = config[\"last_k_res\"]\n",
    "for run_id, path_list in run_id2path_list.items():\n",
    "    run_id2path_list[run_id] = path_list[-k:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading res: 100%|██████████| 5/5 [00:00<00:00, 12.10it/s]\n"
     ]
    }
   ],
   "source": [
    "res_total = []\n",
    "total_path_list = []\n",
    "for path_list in run_id2path_list.values():\n",
    "    total_path_list.extend(path_list)\n",
    "for path in tqdm(total_path_list, desc = \"loading res\"):\n",
    "    res_total.extend([json.loads(line) for line in open(path, \"r\", encoding = \"utf-8\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading into list: 100%|██████████| 2472/2472 [00:00<00:00, 12763.84it/s]\n"
     ]
    }
   ],
   "source": [
    "id_list, text_list, \\\n",
    "subject_list, predicate_list, object_list, \\\n",
    "subj_tok_span_list, subj_char_span_list, \\\n",
    "obj_tok_span_list, obj_char_span_list = [], [], [], [], [], [], [], [], []\n",
    "\n",
    "for sample in tqdm(res_total, desc = \"loading into list\"):\n",
    "    for rel in sample[\"relation_list\"]:\n",
    "        id_list.append(sample[\"id\"])\n",
    "        text_list.append(sample[\"text\"])\n",
    "        subject_list.append(rel[\"subject\"])\n",
    "        object_list.append(rel[\"object\"])\n",
    "        predicate_list.append(rel[\"predicate\"])\n",
    "        subj_tok_span_list.append(\"{},{}\".format(*rel[\"subj_tok_span\"]))\n",
    "        obj_tok_span_list.append(\"{},{}\".format(*rel[\"obj_tok_span\"]))\n",
    "        subj_char_span_list.append(\"{},{}\".format(*rel[\"subj_char_span\"]))\n",
    "        obj_char_span_list.append(\"{},{}\".format(*rel[\"obj_char_span\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ensemble_df = pd.DataFrame({\n",
    "    \"id\": id_list,\n",
    "    \"text\": text_list,\n",
    "    \"subject\": subject_list,\n",
    "    \"predicate\": predicate_list,\n",
    "    \"object\": object_list,\n",
    "    \"subj_tok_span\": subj_tok_span_list,\n",
    "    \"subj_char_span\": subj_char_span_list,\n",
    "    \"obj_tok_span\": obj_tok_span_list,\n",
    "    \"obj_char_span\": obj_char_span_list,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>text</th>\n",
       "      <th>subject</th>\n",
       "      <th>predicate</th>\n",
       "      <th>object</th>\n",
       "      <th>subj_tok_span</th>\n",
       "      <th>subj_char_span</th>\n",
       "      <th>obj_tok_span</th>\n",
       "      <th>obj_char_span</th>\n",
       "      <th>num</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>val_1331</td>\n",
       "      <td>Improving access to analgesic drugs for patien...</td>\n",
       "      <td>analgesic drugs</td>\n",
       "      <td>PositivelyRegulates</td>\n",
       "      <td>cancer</td>\n",
       "      <td>5,9</td>\n",
       "      <td>20,35</td>\n",
       "      <td>12,13</td>\n",
       "      <td>54,60</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>val_1331</td>\n",
       "      <td>Improving access to analgesic drugs for patien...</td>\n",
       "      <td>analgesic drugs</td>\n",
       "      <td>PositivelyRegulates</td>\n",
       "      <td>cancer</td>\n",
       "      <td>5,9</td>\n",
       "      <td>20,35</td>\n",
       "      <td>125,126</td>\n",
       "      <td>604,610</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>val_1331</td>\n",
       "      <td>Improving access to analgesic drugs for patien...</td>\n",
       "      <td>analgesic drugs</td>\n",
       "      <td>PositivelyRegulates</td>\n",
       "      <td>cancer</td>\n",
       "      <td>5,9</td>\n",
       "      <td>20,35</td>\n",
       "      <td>23,24</td>\n",
       "      <td>109,115</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>val_1331</td>\n",
       "      <td>Improving access to analgesic drugs for patien...</td>\n",
       "      <td>analgesic drugs</td>\n",
       "      <td>PositivelyRegulates</td>\n",
       "      <td>pain</td>\n",
       "      <td>5,9</td>\n",
       "      <td>20,35</td>\n",
       "      <td>131,132</td>\n",
       "      <td>642,646</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>val_1331</td>\n",
       "      <td>Improving access to analgesic drugs for patien...</td>\n",
       "      <td>analgesic drugs</td>\n",
       "      <td>PositivelyRegulates</td>\n",
       "      <td>pain</td>\n",
       "      <td>5,9</td>\n",
       "      <td>20,35</td>\n",
       "      <td>89,90</td>\n",
       "      <td>453,457</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         id                                               text  \\\n",
       "0  val_1331  Improving access to analgesic drugs for patien...   \n",
       "1  val_1331  Improving access to analgesic drugs for patien...   \n",
       "2  val_1331  Improving access to analgesic drugs for patien...   \n",
       "3  val_1331  Improving access to analgesic drugs for patien...   \n",
       "4  val_1331  Improving access to analgesic drugs for patien...   \n",
       "\n",
       "           subject            predicate  object subj_tok_span subj_char_span  \\\n",
       "0  analgesic drugs  PositivelyRegulates  cancer           5,9          20,35   \n",
       "1  analgesic drugs  PositivelyRegulates  cancer           5,9          20,35   \n",
       "2  analgesic drugs  PositivelyRegulates  cancer           5,9          20,35   \n",
       "3  analgesic drugs  PositivelyRegulates    pain           5,9          20,35   \n",
       "4  analgesic drugs  PositivelyRegulates    pain           5,9          20,35   \n",
       "\n",
       "  obj_tok_span obj_char_span  num  \n",
       "0        12,13         54,60    5  \n",
       "1      125,126       604,610    3  \n",
       "2        23,24       109,115    5  \n",
       "3      131,132       642,646    1  \n",
       "4        89,90       453,457    2  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ensemble_df_w_duplicate_num = ensemble_df.groupby(ensemble_df.columns.tolist(), as_index = False).size().reset_index().rename(columns={0: 'num'})\n",
    "ensemble_df_w_duplicate_num.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2647\n"
     ]
    }
   ],
   "source": [
    "vote_threshold = config[\"vote_threshold\"]\n",
    "ensemble_res_df = ensemble_df_w_duplicate_num[ensemble_df_w_duplicate_num.num >= vote_threshold]\n",
    "print(len(ensemble_res_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2647/2647 [00:01<00:00, 1473.55it/s]\n",
      "100%|██████████| 369/369 [00:00<00:00, 371328.74it/s]\n"
     ]
    }
   ],
   "source": [
    "id2text, id2relations = {}, {}\n",
    "for idx in tqdm(range(len(ensemble_res_df))):\n",
    "    row = ensemble_res_df.iloc[idx]\n",
    "    id2text[row.id] = row.text\n",
    "    if row.id not in id2relations:\n",
    "        id2relations[row.id] = []\n",
    "\n",
    "    subj_char_span = row.subj_char_span.split(\",\")\n",
    "    subj_tok_span = row.subj_tok_span.split(\",\")\n",
    "    obj_char_span = row.obj_char_span.split(\",\")\n",
    "    obj_tok_span = row.obj_tok_span.split(\",\")\n",
    "    \n",
    "    id2relations[row.id].append({\n",
    "        \"subject\": row.subject,\n",
    "        \"predicate\": row.predicate,\n",
    "        \"object\": row.object,\n",
    "        \"subj_char_span\": [int(subj_char_span[0]), int(subj_char_span[1])],\n",
    "        \"subj_tok_span\": [int(subj_tok_span[0]), int(subj_tok_span[1])],\n",
    "        \"obj_char_span\": [int(obj_char_span[0]), int(obj_char_span[1])],\n",
    "        \"obj_tok_span\": [int(obj_tok_span[0]), int(obj_tok_span[1])]\n",
    "    })\n",
    "\n",
    "emsemble_res = []\n",
    "for idx, text in tqdm(id2text.items()):\n",
    "    emsemble_res.append({\n",
    "        \"text\": text,\n",
    "        \"id\": idx,\n",
    "        \"relation_list\": id2relations[idx],\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 369/369 [00:00<00:00, 9030.38it/s]\n"
     ]
    }
   ],
   "source": [
    "ensemble_res_dir = config[\"ensemble_res_dir\"]\n",
    "if not os.path.exists(ensemble_res_dir):\n",
    "    os.makedirs(ensemble_res_dir)\n",
    "    \n",
    "file_num = len(glob(os.path.join(ensemble_res_dir, \"*ensemble*.json\")))\n",
    "save_path = os.path.join(ensemble_res_dir, \"ensemble_res_{}.json\".format(file_num))\n",
    "\n",
    "with open(save_path, \"w\", encoding = \"utf-8\") as file_out:\n",
    "    for sample in tqdm(emsemble_res):\n",
    "        json_line = json.dumps(sample, ensure_ascii = False)\n",
    "        file_out.write(\"{}\\n\".format(json_line))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
